{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWD-LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_12 import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=6317)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = datasets.untar_data(datasets.URLs.IMDB)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have to preprocess the data again before pickling it: if we try to load the previous `SplitLabeledData` with pickle, it complains that some of the functions aren't defined in `__main__`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "il = TextList.from_files(path, include=['train', 'test', 'unsup'])\n",
    "sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc_tok,proc_num = TokenizeProcessor(max_workers=8),NumericalizeProcessor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ll = label_by_func(sd, lambda x: 0, proc_x = [proc_tok,proc_num])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(ll, open(path/'ll_lm.pkl', 'wb'))\n",
    "pickle.dump(proc_num.vocab, open(path/'vocab_lm.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ll = pickle.load(open(path/'ll_lm.pkl', 'rb'))\n",
    "vocab = pickle.load(open(path/'vocab_lm.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs,bptt = 64,70\n",
    "data = lm_databunchify(ll, bs, bptt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AWD-LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before explaining what an AWD-LSTM is, we need to start with an LSTM. RNNs were covered in part 1; if you need a refresher, there is a great visualization of them on [this website](http://joshvarty.github.io/VisualizingRNNs/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=6330)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LSTM from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to implement these equations (where $\\sigma$ stands for sigmoid):\n",
    "\n",
    "![LSTM cell and equations](images/lstm.jpg)\n",
    "(picture from [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) by Chris Olah.)\n",
    "\n",
    "If we want to take advantage of our GPU, it's better to do one big matrix multiplication than four smaller ones. So we compute the values of the four gates all at once. Since there is a matrix multiplication and a bias, we use `nn.Linear` to do it.\n",
    "\n",
    "We need two linear layers: one for the input and one for the hidden state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(nn.Module):\n",
    "    def __init__(self, ni, nh):\n",
    "        super().__init__()\n",
    "        self.ih = nn.Linear(ni,4*nh)\n",
    "        self.hh = nn.Linear(nh,4*nh)\n",
    "\n",
    "    def forward(self, input, state):\n",
    "        h,c = state\n",
    "        #One big multiplication for all the gates is better than 4 smaller ones\n",
    "        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)\n",
    "        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])\n",
    "        cellgate = gates[3].tanh()\n",
    "\n",
    "        c = (forgetgate*c) + (ingate*cellgate)\n",
    "        h = outgate * c.tanh()\n",
    "        return h, (h,c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then an LSTM layer just applies the cell on all the time steps in order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMLayer(nn.Module):\n",
    "    def __init__(self, cell, *cell_args):\n",
    "        super().__init__()\n",
    "        self.cell = cell(*cell_args)\n",
    "\n",
    "    def forward(self, input, state):\n",
    "        inputs = input.unbind(1)\n",
    "        outputs = []\n",
    "        for i in range(len(inputs)):\n",
    "            out, state = self.cell(inputs[i], state)\n",
    "            outputs += [out]\n",
    "        return torch.stack(outputs, dim=1), state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's try it out and see how fast we are. We only measure the forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = LSTMLayer(LSTMCell, 300, 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(64, 70, 300)\n",
    "h = (torch.zeros(64, 300),torch.zeros(64, 300))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 y,h1 = lstm(x,h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = lstm.cuda()\n",
    "x = x.cuda()\n",
    "h = (h[0].cuda(), h[1].cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
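    "def time_fn(f):\n",
    "    # Run `f`, then wait for all queued CUDA kernels to finish, so that\n",
    "    # `%timeit` measures the whole computation and not just the kernel launches.\n",
    "    f()\n",
    "    torch.cuda.synchronize()"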
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = partial(lstm,x,h)\n",
    "time_fn(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 time_fn(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Builtin version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's compare with PyTorch!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=6718)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = nn.LSTM(300, 300, 1, batch_first=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(64, 70, 300)\n",
    "h = (torch.zeros(1, 64, 300),torch.zeros(1, 64, 300))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 y,h1 = lstm(x,h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = lstm.cuda()\n",
    "x = x.cuda()\n",
    "h = (h[0].cuda(), h[1].cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = partial(lstm,x,h)\n",
    "time_fn(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 time_fn(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So our version runs at almost the same speed on the CPU. On the GPU, however, PyTorch uses cuDNN behind the scenes, which greatly optimizes the for loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Jit version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=6744)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.jit as jit\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have to write everything from scratch to be a bit faster, so we don't use the linear layers here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMCell(jit.ScriptModule):\n",
    "    def __init__(self, ni, nh):\n",
    "        super().__init__()\n",
    "        self.ni = ni\n",
    "        self.nh = nh\n",
    "        self.w_ih = nn.Parameter(torch.randn(4 * nh, ni))\n",
    "        self.w_hh = nn.Parameter(torch.randn(4 * nh, nh))\n",
    "        self.bias_ih = nn.Parameter(torch.randn(4 * nh))\n",
    "        self.bias_hh = nn.Parameter(torch.randn(4 * nh))\n",
    "\n",
    "    @jit.script_method\n",
    "    def forward(self, input:Tensor, state:Tuple[Tensor, Tensor])->Tuple[Tensor, Tuple[Tensor, Tensor]]:\n",
    "        hx, cx = state\n",
    "        gates = (input @ self.w_ih.t() + self.bias_ih +\n",
    "                 hx @ self.w_hh.t() + self.bias_hh)\n",
    "        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)\n",
    "\n",
    "        ingate = torch.sigmoid(ingate)\n",
    "        forgetgate = torch.sigmoid(forgetgate)\n",
    "        cellgate = torch.tanh(cellgate)\n",
    "        outgate = torch.sigmoid(outgate)\n",
    "\n",
    "        cy = (forgetgate * cx) + (ingate * cellgate)\n",
    "        hy = outgate * torch.tanh(cy)\n",
    "\n",
    "        return hy, (hy, cy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMLayer(jit.ScriptModule):\n",
    "    def __init__(self, cell, *cell_args):\n",
    "        super().__init__()\n",
    "        self.cell = cell(*cell_args)\n",
    "\n",
    "    @jit.script_method\n",
    "    def forward(self, input:Tensor, state:Tuple[Tensor, Tensor])->Tuple[Tensor, Tuple[Tensor, Tensor]]:\n",
    "        inputs = input.unbind(1)\n",
    "        outputs = []\n",
    "        for i in range(len(inputs)):\n",
    "            out, state = self.cell(inputs[i], state)\n",
    "            outputs += [out]\n",
    "        return torch.stack(outputs, dim=1), state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = LSTMLayer(LSTMCell, 300, 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(64, 70, 300)\n",
    "h = (torch.zeros(64, 300),torch.zeros(64, 300))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 y,h1 = lstm(x,h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = lstm.cuda()\n",
    "x = x.cuda()\n",
    "h = (h[0].cuda(), h[1].cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = partial(lstm,x,h)\n",
    "time_fn(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit -n 10 time_fn(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With jit, we almost get to the cuDNN speed!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dropout"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to use the AWD-LSTM from [Stephen Merity et al.](https://arxiv.org/abs/1708.02182). First, we'll need several different kinds of dropout. Dropout consists of replacing some coefficients with 0 with probability `p`. To ensure that the average of the weights remains constant, we scale the weights that aren't zeroed by a factor of `1/(1-p)` (think of what happens to the activations if you want to figure out why!)\n",
    "\n",
    "We usually apply dropout by drawing a mask that tells us which elements to nullify or not:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=7003)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def dropout_mask(x, sz, p):\n",
    "    return x.new(*sz).bernoulli_(1-p).div_(1-p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(10,10)\n",
    "mask = dropout_mask(x, (10,10), 0.5); mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have a dropout mask `mask`, applying the dropout to `x` is simply done by `x = x * mask`. We create our own dropout mask and don't rely on PyTorch's dropout because we do not want to zero the coefficients independently at random: we always want to replace the same positions with zero along the seq_len dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x*mask).std(),x.std()"
   ]
  },
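  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the `1/(1-p)` correction, here is a minimal sketch that builds a mask the same way `dropout_mask` does (using `new_empty` in place of `x.new`) and applies it to a tensor of ones. The tensor size, the seed and the names `acts` and `dropped` are only chosen for this illustration. The surviving entries are scaled up while the others are zeroed, so the mean should stay close to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "p = 0.5\n",
    "acts = torch.ones(100_000)  # constant activations, so the mean is easy to read\n",
    "# Same construction as `dropout_mask`: keep with probability 1-p, rescale by 1/(1-p)\n",
    "mask = acts.new_empty(acts.shape).bernoulli_(1 - p).div_(1 - p)\n",
    "dropped = acts * mask\n",
    "# Kept entries become 1/(1-p) = 2 and the rest are 0, so the mean stays close to 1\n",
    "print(dropped.mean().item(), dropped.unique())"
   ]
  },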
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside an RNN, a tensor `x` will have three dimensions: bs, seq_len and the number of features (the embedding or hidden size). Recall that we want to apply the dropout mask consistently across the seq_len dimension; therefore, we create a dropout mask for the first and third dimensions and broadcast it along the seq_len dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class RNNDropout(nn.Module):\n",
    "    def __init__(self, p=0.5):\n",
    "        super().__init__()\n",
    "        self.p=p\n",
    "\n",
    "    def forward(self, x):\n",
    "        if not self.training or self.p == 0.: return x\n",
    "        m = dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)\n",
    "        return x * m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dp = RNNDropout(0.3)\n",
    "tst_input = torch.randn(3,3,7)\n",
    "tst_input, dp(tst_input)"
   ]
  },
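  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what sharing the mask along the seq_len dimension means, the sketch below rebuilds the same kind of mask by hand on a tensor of ones. The sizes and variable names are arbitrary choices for this illustration. Because the mask has size 1 on the middle dimension, broadcasting repeats it for every time step, so a feature that is zeroed at one position of a sequence is zeroed at all of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "bs, seq_len, n_feat, p = 2, 5, 4, 0.5\n",
    "x = torch.ones(bs, seq_len, n_feat)\n",
    "# One mask value per (batch item, feature); the size-1 middle dimension broadcasts over seq_len\n",
    "mask = x.new_empty(bs, 1, n_feat).bernoulli_(1 - p).div_(1 - p)\n",
    "out = x * mask\n",
    "# Compare every time step with the first one of the same sequence\n",
    "same_across_time = bool((out == out[:, :1]).all())\n",
    "print(out[0], same_across_time)"
   ]
  },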
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`WeightDropout` is dropout applied to the hidden-to-hidden weight matrix of the inner LSTM. Doing this while preserving the cuDNN speed, without reimplementing the cell from scratch, is a little hacky: we add a parameter that contains the raw weights, and we replace the weight matrix in the LSTM with a dropped-out copy at the beginning of each forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import warnings\n",
    "\n",
    "WEIGHT_HH = 'weight_hh_l0'\n",
    "\n",
    "class WeightDropout(nn.Module):\n",
    "    def __init__(self, module, weight_p=0., layer_names=[WEIGHT_HH]):\n",
    "        super().__init__()\n",
    "        self.module,self.weight_p,self.layer_names = module,weight_p,layer_names\n",
    "        for layer in self.layer_names:\n",
    "            #Makes a copy of the weights of the selected layers.\n",
    "            w = getattr(self.module, layer)\n",
    "            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))\n",
    "            self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False)\n",
    "\n",
    "    def _setweights(self):\n",
    "        for layer in self.layer_names:\n",
    "            raw_w = getattr(self, f'{layer}_raw')\n",
    "            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=self.training)\n",
    "\n",
    "    def forward(self, *args):\n",
    "        self._setweights()\n",
    "        with warnings.catch_warnings():\n",
    "            #To avoid the warning that comes because the weights aren't flattened.\n",
    "            warnings.simplefilter(\"ignore\")\n",
    "            return self.module.forward(*args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module = nn.LSTM(5, 2)\n",
    "dp_module = WeightDropout(module, 0.4)\n",
    "getattr(dp_module.module, WEIGHT_HH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's at the beginning of a forward pass that the dropout is applied to the weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_input = torch.randn(4,20,5)\n",
    "h = (torch.zeros(1,20,2), torch.zeros(1,20,2))\n",
    "x,h = dp_module(tst_input,h)\n",
    "getattr(dp_module.module, WEIGHT_HH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "EmbeddingDropout applies dropout to full rows of the embedding matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class EmbeddingDropout(nn.Module):\n",
    "    \"Applies dropout in the embedding layer by zeroing out entire rows of the embedding matrix.\"\n",
    "    def __init__(self, emb, embed_p):\n",
    "        super().__init__()\n",
    "        self.emb,self.embed_p = emb,embed_p\n",
    "        self.pad_idx = self.emb.padding_idx\n",
    "        if self.pad_idx is None: self.pad_idx = -1\n",
    "\n",
    "    def forward(self, words, scale=None):\n",
    "        if self.training and self.embed_p != 0:\n",
    "            size = (self.emb.weight.size(0),1)\n",
    "            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)\n",
    "            masked_embed = self.emb.weight * mask\n",
    "        else: masked_embed = self.emb.weight\n",
    "        if scale: masked_embed = masked_embed * scale  # out-of-place, so the embedding weights are never modified\n",
    "        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,\n",
    "                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = nn.Embedding(100, 7, padding_idx=1)\n",
    "enc_dp = EmbeddingDropout(enc, 0.5)\n",
    "tst_input = torch.randint(0,100,(8,))\n",
    "enc_dp(tst_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Main model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main model is a regular multi-layer LSTM that uses all of these kinds of dropout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def to_detach(h):\n",
    "    \"Detaches `h` from its history.\"\n",
    "    return h.detach() if isinstance(h, torch.Tensor) else tuple(to_detach(v) for v in h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AWD_LSTM(nn.Module):\n",
    "    \"AWD-LSTM inspired by https://arxiv.org/abs/1708.02182.\"\n",
    "    initrange=0.1\n",
    "\n",
    "    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token,\n",
    "                 hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):\n",
    "        super().__init__()\n",
    "        self.bs,self.emb_sz,self.n_hid,self.n_layers = 1,emb_sz,n_hid,n_layers\n",
    "        self.emb = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)\n",
    "        self.emb_dp = EmbeddingDropout(self.emb, embed_p)\n",
    "        self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz), 1,\n",
    "                             batch_first=True) for l in range(n_layers)]\n",
    "        self.rnns = nn.ModuleList([WeightDropout(rnn, weight_p) for rnn in self.rnns])\n",
    "        self.emb.weight.data.uniform_(-self.initrange, self.initrange)\n",
    "        self.input_dp = RNNDropout(input_p)\n",
    "        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])\n",
    "\n",
    "    def forward(self, input):\n",
    "        bs,sl = input.size()\n",
    "        if bs!=self.bs:\n",
    "            self.bs=bs\n",
    "            self.reset()\n",
    "        raw_output = self.input_dp(self.emb_dp(input))\n",
    "        new_hidden,raw_outputs,outputs = [],[],[]\n",
    "        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):\n",
    "            raw_output, new_h = rnn(raw_output, self.hidden[l])\n",
    "            new_hidden.append(new_h)\n",
    "            raw_outputs.append(raw_output)\n",
    "            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)\n",
    "            outputs.append(raw_output)\n",
    "        self.hidden = to_detach(new_hidden)\n",
    "        return raw_outputs, outputs\n",
    "\n",
    "    def _one_hidden(self, l):\n",
    "        \"Return one hidden state.\"\n",
    "        nh = self.n_hid if l != self.n_layers - 1 else self.emb_sz\n",
    "        return next(self.parameters()).new(1, self.bs, nh).zero_()\n",
    "\n",
    "    def reset(self):\n",
    "        \"Reset the hidden states.\"\n",
    "        self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On top of this, we will apply a linear decoder. It's often best to tie its weights to the embeddings, that is, to use the embedding matrix itself as the weight matrix of the decoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LinearDecoder(nn.Module):\n",
    "    def __init__(self, n_out, n_hid, output_p, tie_encoder=None, bias=True):\n",
    "        super().__init__()\n",
    "        self.output_dp = RNNDropout(output_p)\n",
    "        self.decoder = nn.Linear(n_hid, n_out, bias=bias)\n",
    "        if bias: self.decoder.bias.data.zero_()\n",
    "        if tie_encoder: self.decoder.weight = tie_encoder.weight\n",
    "        else: init.kaiming_uniform_(self.decoder.weight)\n",
    "\n",
    "    def forward(self, input):\n",
    "        raw_outputs, outputs = input\n",
    "        output = self.output_dp(outputs[-1]).contiguous()\n",
    "        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))\n",
    "        return decoded, raw_outputs, outputs"
   ]
  },
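  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tying the weights simply means that the decoder and the embedding hold the same `Parameter`. Here is a small standalone sketch with made-up sizes (a vocabulary of 10 tokens and an embedding size of 4) showing that, after the same kind of assignment as in `LinearDecoder`, a change to the embedding matrix is visible in the decoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "emb = nn.Embedding(10, 4)               # weight of shape (10, 4)\n",
    "decoder = nn.Linear(4, 10, bias=False)  # weight of shape (10, 4) as well\n",
    "decoder.weight = emb.weight             # tie: both modules now hold the same Parameter\n",
    "shared = decoder.weight is emb.weight\n",
    "# Overwrite the first embedding row; the decoder sees the change immediately\n",
    "with torch.no_grad(): emb.weight[0].fill_(1.)\n",
    "print(shared, decoder.weight[0])"
   ]
  },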
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SequentialRNN(nn.Sequential):\n",
    "    \"A sequential module that passes the reset call to its children.\"\n",
    "    def reset(self):\n",
    "        for c in self.children():\n",
    "            if hasattr(c, 'reset'): c.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we stack all of this together!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def get_language_model(vocab_sz, emb_sz, n_hid, n_layers, pad_token, output_p=0.4, hidden_p=0.2, input_p=0.6, \n",
    "                       embed_p=0.1, weight_p=0.5, tie_weights=True, bias=True):\n",
    "    rnn_enc = AWD_LSTM(vocab_sz, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token,\n",
    "                       hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)\n",
    "    enc = rnn_enc.emb if tie_weights else None\n",
    "    return SequentialRNN(rnn_enc, LinearDecoder(vocab_sz, emb_sz, output_p, tie_encoder=enc, bias=bias))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tok_pad = vocab.index(PAD)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can check that all of this runs without raising an error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_model = get_language_model(len(vocab), 300, 300, 2, tok_pad)\n",
    "tst_model = tst_model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = next(iter(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = tst_model(x.cuda())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We return three things to help with regularization: the true output (the unnormalized scores, or logits, for each word), but also the activations of the encoder, with and without dropout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoded, raw_outputs, outputs = z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decoded tensor is flattened to `bs * seq_len` by `len(vocab)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoded.size()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`raw_outputs` and `outputs` each contain one tensor per LSTM layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(raw_outputs),len(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[o.size() for o in raw_outputs], [o.size() for o in outputs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Callbacks to train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to add a few tweaks to train a language model: first we will clip the gradients. This is a classic technique that will allow us to use a higher learning rate by putting a maximum value on the norm of the gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GradientClipping(Callback):\n",
    "    def __init__(self, clip=None): self.clip = clip\n",
    "    def after_backward(self):\n",
    "        if self.clip:  nn.utils.clip_grad_norm_(self.run.model.parameters(), self.clip)"
   ]
  },
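  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of what `nn.utils.clip_grad_norm_` does, the sketch below gives a tiny linear layer (a stand-in for the model, with hand-set gradients) a gradient of norm 20 and clips it to 0.1. The function returns the total norm measured before clipping and rescales the gradients in place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "lin = nn.Linear(3, 1)  # hypothetical stand-in for the model: 3 weights + 1 bias\n",
    "# Hand-set gradients: 4 entries equal to 10, so the total norm is sqrt(4 * 100) = 20\n",
    "for prm in lin.parameters(): prm.grad = torch.full_like(prm, 10.)\n",
    "norm_before = nn.utils.clip_grad_norm_(lin.parameters(), max_norm=0.1)\n",
    "norm_after = torch.cat([prm.grad.flatten() for prm in lin.parameters()]).norm()\n",
    "print(float(norm_before), float(norm_after))"
   ]
  },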
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we add an `RNNTrainer` that will do four things:\n",
    "- change the output to make it contain only the `decoded` tensor (for the loss function) and store the `raw_outputs` and `outputs`\n",
    "- apply Activation Regularization (AR): we add to the loss an L2 penalty on the last activations of the AWD LSTM (with dropout applied)\n",
    "- apply Temporal Activation Regularization (TAR): we add to the loss an L2 penalty on the difference between two consecutive (in terms of words) raw outputs\n",
    "- trigger the shuffle of the LMDataset at the beginning of each epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class RNNTrainer(Callback):\n",
    "    def __init__(self, α, β): self.α,self.β = α,β\n",
    "    \n",
    "    def after_pred(self):\n",
    "        #Save the extra outputs for later and only return the true output.\n",
    "        self.raw_out,self.out = self.pred[1],self.pred[2]\n",
    "        self.run.pred = self.pred[0]\n",
    "    \n",
    "    def after_loss(self):\n",
    "        #AR and TAR\n",
    "        if self.α != 0.:  self.run.loss += self.α * self.out[-1].float().pow(2).mean()\n",
    "        if self.β != 0.:\n",
    "            h = self.raw_out[-1]\n",
    "            if h.size(1)>1: self.run.loss += self.β * (h[:,1:] - h[:,:-1]).float().pow(2).mean()\n",
    "                \n",
    "    def begin_epoch(self):\n",
    "        #Shuffle the texts at the beginning of the epoch\n",
    "        if hasattr(self.dl.dataset, \"batchify\"): self.dl.dataset.batchify()"
   ]
  },
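  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two penalties concrete, here is a small worked example on a hand-written activation tensor. The values, and the choice of using the same tensor for the output with and without dropout, are assumptions made to keep the arithmetic easy to follow; `alpha` and `beta` play the roles of `α` and `β` in `RNNTrainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "alpha, beta = 2., 1.\n",
    "# Stand-in for the last layer's activations: 1 sequence, 3 time steps, 2 features\n",
    "out = torch.tensor([[[1., 1.], [2., 2.], [4., 4.]]])  # with dropout (assumed identical here)\n",
    "raw_out = out.clone()                                  # without dropout\n",
    "# AR penalizes large activations: mean of squares is 42/6 = 7, so AR = 2 * 7 = 14\n",
    "ar = alpha * out.pow(2).mean()\n",
    "# TAR penalizes jumps between consecutive time steps: mean of (1, 1, 4, 4) = 2.5\n",
    "tar = beta * (raw_out[:, 1:] - raw_out[:, :-1]).pow(2).mean()\n",
    "print(ar.item(), tar.item())"
   ]
  },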
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly we write a flattened version of the cross entropy loss and the accuracy metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def cross_entropy_flat(input, target):\n",
    "    bs,sl = target.size()\n",
    "    return F.cross_entropy(input.view(bs * sl, -1), target.view(bs * sl))\n",
    "\n",
    "def accuracy_flat(input, target):\n",
    "    bs,sl = target.size()\n",
    "    return accuracy(input.view(bs * sl, -1), target.view(bs * sl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_sz, nh, nl = 300, 300, 2\n",
    "model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, input_p=0.6, output_p=0.4, weight_p=0.5, \n",
    "                           embed_p=0.1, hidden_p=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cbs = [partial(AvgStatsCallback,accuracy_flat),\n",
    "       CudaCallback, Recorder,\n",
    "       partial(GradientClipping, clip=0.1),\n",
    "       partial(RNNTrainer, α=2., β=1.),\n",
    "       ProgressCallback]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 12a_awd_lstm.ipynb to exp/nb_12a.py\r\n"
     ]
    }
   ],
   "source": [
    "!python notebook2script.py 12a_awd_lstm.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
